clear

% Load data ----------------------------------------------------------------
% Everything used below that is not assigned in this script (e.g. thetastar,
% phi, P, zeta, Vbsave and the dimensions H, K, Nu, tau, ...) must come from
% the files loaded here, since the workspace is cleared first.
load('Results/output.mat');                             % Stage-1 data
load('Results/muHistory');                              % Stage-1 parameter estimates
theta = muHistory(end,:);                               % last row of the estimate history
quota = csvread('Bear Quota Average.csv');              % Quota
load('Results/Vbsave.mat')
load('Results/P_zeta.mat')                              % presumably provides P and zeta -- confirm

% Three weights indexed later by the tau dimension (pit(idx) for idx = 1:tau):
% the first two are thetastar(4:5), the third is the remainder so they sum
% to one. reshape forces a row so this works whether thetastar is stored as
% a row or a column vector (a column previously made the concatenation fail).
pit = [reshape(thetastar(4:5),1,[]), 1 - sum(thetastar(4:5))];

% Entries where phi is exactly zero are flagged as impossible.
impossible = phi == 0;

% Solve (or restore) the baseline equilibrium ---------------------------------
% Set runBaseline to a nonzero value to (re)solve the baseline with dreumWTP;
% with 0 the previously saved baseline workspace (sigma0save.mat) is loaded.
% Named runBaseline rather than `run` to avoid shadowing MATLAB's run().
runBaseline = 0;

% Convergence tolerance on the max absolute change in outsigma0.
convTol = 1e-6;

if runBaseline ~= 0

    if exist('outsigma0save.mat','file') ~= 0 && exist('Np0save.mat','file') ~= 0
        % Resume from the checkpoints written by the loop below.
        load('outsigma0save.mat');
        load('Np0save.mat');
    else
        % PTC(u): fraction of rows of tcidx equal to u.
        PTC = zeros(Nu,1);
        for idx1 = 1:Nu
            PTC(idx1) = sum(tcidx==idx1)/size(tcidx,1);
        end

        % PpTC(u,k): zeta(k,1,u,t) weighted by pit over t = 1..3.
        % NOTE(review): hard-codes three values of the last zeta index
        % (i.e. assumes tau == 3) -- confirm.
        PpTC = zeros(Nu,K);
        for idx2 = 1:Nu
            for idx3 = 1:K
                PpTC(idx2,idx3) = zeta(idx3,1,idx2,1)*pit(1) ...
                    + zeta(idx3,1,idx2,2)*pit(2) + zeta(idx3,1,idx2,3)*pit(3);
            end
        end

        % Pp(k): PpTC further weighted by PTC over u.
        Pp = PpTC'*PTC;

        % Where `impossible` is true, multiply P by 1/(number of impossible
        % entries in that column). Columns with no impossible entry give
        % 0/0 in impCt, but those entries are never indexed.
        impCt = impossible./repmat(sum(impossible),[H+1,1]);
        for idx1 = 1:tau
            for idx2 = 1:Nu
                y = P(:,:,idx2,idx1);
                y(impossible) = y(impossible).*impCt(impossible);
                P(:,:,idx2,idx1) = y;
            end
        end

        % Calculate estimated shares of applicants with k preference points
        % that visit each site (rows 2..H+1 of P are read for sites 1..H).
        x = zeros(H,K,Nu,tau);
        for idx1 = 1:tau
            for idx2 = 1:H
                for idx3 = 1:K
                    for idx4 = 1:Nu
                        x(idx2,idx3,idx4,idx1) = P(idx2+1,idx3,idx4,idx1)...
                            *zeta(idx3,1,idx4,idx1)*pit(idx1)*PTC(idx4)/Pp(idx3);
                    end
                end
            end
        end

        x = sum(x,4);
        outsigma0 = sum(x,3);
        % Prepend the residual share as row 1 so each column sums to one.
        outsigma0 = [1 - sum(outsigma0); outsigma0];

        % NOTE(review): 55454 is a hard-coded scale, presumably the total
        % number of applicants -- confirm and consider loading it as data.
        Np = Pp'*55454;

    end

    % Warm-start the value function. The loop below checkpoints Vbsave to
    % 'Vbsave0.mat' in the working directory, so look there first; fall back
    % to the 'Results/' copy, which was the only location checked before.
    if exist('Vbsave0.mat','file') ~= 0
        load('Vbsave0.mat');
    elseif exist('Results/Vbsave0.mat','file') ~= 0
        load('Results/Vbsave0.mat');
    end

    % Damped fixed-point iteration on the application shares outsigma0.
    % (maxDiff instead of `diff` to avoid shadowing MATLAB's diff().)
    maxDiff = 2*convTol;
    while maxDiff > convTol

        insigma = outsigma0;

        a = .7;   % weight on the previous iterate when damping

        [Np,P0,phi0,PTC,outsigma0,R0,Vb0,Vbsave,zeta0] = dreumWTP(chunksize,H,...
            K,Np,Nu,quota,rho,insigma,tau,tcidx,tcSplit,theta,tolInner,a,Vbsave);
        maxDiff = max(max(abs(insigma - outsigma0)));

        fprintf('Max difference: %e\n',maxDiff) % Print output for debugging

        % Checkpoint every iteration so an interrupted run can be resumed.
        save('outsigma0save.mat','outsigma0');
        save('Np0save.mat','Np');
        save('Vbsave0.mat','Vbsave');

        if maxDiff > tolInner
            outsigma0 = a*insigma + outsigma0*(1-a);
        end

    end

    save('sigma0save.mat')

else

    load('sigma0save.mat')

end


% Drop a BMU from set of choice alternatives, calculate new eq-------------
%
% Quota indices by BMU:
% Bergland = 1:3
% Baraga = 4:6
% Amasa = 7:9
% Carney = 10:12
% Gwinn = 13:15
% Newberry = 16:18
% Drummond Island = 19
% Red Oak = 20
% Baldwin = 21
% Gladwin = 22

drop = 16:18;   % Newberry

% Close the dropped sites by setting their quota to zero
quota1 = quota;
quota1(drop) = 0;

% Start the counterfactual iteration from the baseline solution.
outsigma1 = outsigma0;
% Np is already in the workspace from the baseline section (solved above or
% restored from sigma0save.mat). Only fall back to the checkpoint file when
% it is missing, so Np1 starts from the same solution as outsigma1/Vbsave1.
if ~exist('Np','var')
    load('Np0save.mat');
end
Np1 = Np;

Vbsave1 = Vbsave;

% Damped fixed-point iteration on outsigma1 under the reduced quota.
% (maxDiff instead of `diff` to avoid shadowing MATLAB's diff().)
convTol = 1e-6;
maxDiff = 2*convTol;

while maxDiff > convTol

    insigma = outsigma1;

    a = .5;   % weight on the previous iterate when damping

    [Np1,P1,phi1,PTC,outsigma1,R1,Vb1,Vbsave1,zeta1] = dreumWTP(chunksize,H,...
            K,Np1,Nu,quota1,rho,insigma,tau,tcidx,tcSplit,theta,tolInner,a,Vbsave1);

    maxDiff = max(max(abs(insigma - outsigma1)));

    fprintf('Max difference: %e\n',maxDiff) % Print output

    % Checkpoint every iteration.
    save('Vbsave1.mat','Vbsave1');
    save('outsigma1save.mat','outsigma1');
    save('Np1save.mat','Np1');

    if maxDiff > tolInner
        outsigma1 = a*insigma + outsigma1*(1-a);
    end

end

save('sigma1save.mat')

% Calculate WTP0-----------------------------------------------------------
% Y0(u,k,t) = R0(k) - rho * [ sum_{h=1..H} P0(h,k) * ( phi0(h,k)*R0(1)
%                              + (1-phi0(h,k))*R0(min(k+1,K)) )
%                            + P0(H+1,k)*R0(k) ]
% (all R0/P0 slices taken at type u, t). WTP0(u,t) divides Y0(u,:,t) by
% abs(mu/200) and weights it by zeta0(:,1,u,t), pit(t) and PTC(u).
% NOTE(review): confirm that P0/phi0 returned by dreumWTP use rows 1..H for
% sites and row H+1 for the remaining option, as this formula assumes.
Y0 = zeros(Nu,K,tau);
WTP0 = zeros(Nu,tau);      % preallocated; previously grown inside the loop
mu = thetastar(1);
wtpScale = abs(mu/200);    % NOTE(review): meaning of the 200 factor -- confirm
for idx2 = 1:tau
    for idx1 = 1:Nu
        for idx3 = 1:K
            % R0 index for the (1-phi0) term: one step up, capped at K
            % (this replaces the former idx3 < K / idx3 == K branches).
            kNext = min(idx3+1, K);
            Y0(idx1,idx3,idx2) = R0(idx3,1,idx1,idx2) - rho*(sum(P0(1:H,idx3,idx1,idx2).*(phi0(1:H,idx3)*R0(1,1,idx1,idx2) ...
                +(1-phi0(1:H,idx3)).*R0(kNext,1,idx1,idx2))) + P0(H+1,idx3,idx1,idx2)*R0(idx3,1,idx1,idx2));
        end

        WTP0(idx1,idx2) = Y0(idx1,:,idx2)/wtpScale*zeta0(:,1,idx1,idx2)*pit(idx2)*PTC(idx1);

    end
end


% Calculate WTP1-----------------------------------------------------------
% Same computation as WTP0 but on the counterfactual solution (R1, P1,
% phi1, zeta1):
% Y1(u,k,t) = R1(k) - rho * [ sum_{h=1..H} P1(h,k) * ( phi1(h,k)*R1(1)
%                              + (1-phi1(h,k))*R1(min(k+1,K)) )
%                            + P1(H+1,k)*R1(k) ]
% Row H+1 of P1 is indexed explicitly (was `end`), matching the 1:H split.
% Relies on mu assigned in the WTP0 section.
Y1 = zeros(Nu,K,tau);
WTP1 = zeros(Nu,tau);      % preallocated; previously grown inside the loop
wtpScale = abs(mu/200);    % NOTE(review): meaning of the 200 factor -- confirm
for idx2 = 1:tau
    for idx1 = 1:Nu
        for idx3 = 1:K
            % R1 index for the (1-phi1) term: one step up, capped at K
            % (this replaces the former idx3 < K / idx3 == K branches).
            kNext = min(idx3+1, K);
            Y1(idx1,idx3,idx2) = R1(idx3,1,idx1,idx2) - rho*(sum(P1(1:H,idx3,idx1,idx2).*(phi1(1:H,idx3)*R1(1,1,idx1,idx2) ...
                +(1-phi1(1:H,idx3)).*R1(kNext,1,idx1,idx2))) + P1(H+1,idx3,idx1,idx2)*R1(idx3,1,idx1,idx2));
        end

        WTP1(idx1,idx2) = Y1(idx1,:,idx2)/wtpScale*zeta1(:,1,idx1,idx2)*pit(idx2)*PTC(idx1);

    end
end

% Change in WTP from the counterfactual. Each WTP entry already carries its
% pit and PTC weights, so this is a weighted total over all (u, t) cells:
% first the per-u change summed over t, then summed over u.
gainByType = sum(WTP1, 2) - sum(WTP0, 2);
WTPmean = sum(gainByType, 1);